#!/usr/bin/python3

import os
import sys
import time

import pickle

import pynlpir
from pynlpir import nlpir

from hanzi_util import is_zh_number,is_terminator,is_punct,is_zh,is_zhs
import hanzi_prep

import pinyin

import threading
import socket
import queue

import ctypes
from ctypes import c_char_p, c_wchar_p

# Server-side configuration.
HOST = "192.168.1.163"
PORT = 10240

FILE_NAME = "./data/jd.dat"
# Pre-computed word-segmentation artefacts derived from FILE_NAME.
FILE_NAME_JIEBA_PK = FILE_NAME + "_JIEBA_PK"
FILE_NAME_JIEBA_PINYIN = FILE_NAME + "_JIEBA_PINYIN"

# Frequency table; correct_me() looks it up with the concatenation of two
# adjacent words.
JIEBA_HZ = {}
if not os.path.exists(FILE_NAME_JIEBA_PK):
    print("请计算产生词频数据!")
    # The original called os.exit(), which does not exist and raised
    # AttributeError; exit with a non-zero status instead.
    sys.exit(1)
print("加载JIEBA词频信息")
# NOTE: pickle.load must only be used on trusted, locally generated files.
with open(FILE_NAME_JIEBA_PK, 'rb') as fin:
    JIEBA_HZ = pickle.load(fin)

# Pinyin index; correct_me() looks it up with a pinyin key and iterates the
# stored value as a collection of strings that are also JIEBA_HZ keys.
JIEBA_PINYIN = {}
if not os.path.exists(FILE_NAME_JIEBA_PINYIN):
    print("请计算产生拼音信息!")
    sys.exit(1)
print("加载JIEBA拼音信息")
with open(FILE_NAME_JIEBA_PINYIN, 'rb') as fin:
    JIEBA_PINYIN = pickle.load(fin)


def _homophone_bigrams(first, second):
    """Return the JIEBA_PINYIN entries stored under the pinyin of both words.

    The key is ``pinyin.word2pinyin_split(first, '-')`` and
    ``pinyin.word2pinyin_split(second, '-')`` joined by ``'-'``.  An unknown
    key yields an empty tuple.
    """
    key = (pinyin.word2pinyin_split(first, '-') + '-'
           + pinyin.word2pinyin_split(second, '-'))
    return JIEBA_PINYIN.get(key) or ()


def _replacements_for_right(left, right):
    """Map each replacement for ``right`` to its frequency, keeping ``left`` fixed."""
    return {
        bigram[len(left):]: JIEBA_HZ.get(bigram)
        for bigram in _homophone_bigrams(left, right)
        if bigram.startswith(left)
    }


def _replacements_for_left(left, right):
    """Map each replacement for ``left`` to its frequency, keeping ``right`` fixed.

    A replacement is assumed to have the same number of characters as
    ``left``, so the bigram is cut at ``len(left)``.
    """
    cut = len(left)
    return {
        bigram[:cut]: JIEBA_HZ.get(bigram)
        for bigram in _homophone_bigrams(left, right)
        if bigram[cut:] == right
    }


def _with_context(candidates, context, context_first):
    """Re-score ``candidates`` against a neighbouring context word.

    Without a context word the candidates are returned unchanged.  Otherwise
    only candidates whose concatenation with the context word (context first
    when ``context_first`` is true, else last) exists in JIEBA_HZ are kept,
    mapped to that entry's frequency.
    """
    if not context:
        return candidates
    scored = {}
    for cand in candidates:
        bigram = context + cand if context_first else cand + context
        if bigram in JIEBA_HZ:
            scored[cand] = JIEBA_HZ.get(bigram)
    return scored


def _combine_scores(left_scores, right_scores):
    """Score the candidates present in both maps with ``a*b / (a+b)``.

    Returns ``(scores, best_item, best_score)``; ``best_item`` is ``None``
    and ``best_score`` is 0 when no candidate scores above zero.  A missing
    or zero frequency yields a score of 0 instead of raising.
    """
    scores = {}
    best_item, best_score = None, 0
    for cand in left_scores:
        if cand not in right_scores:
            continue
        a, b = left_scores[cand], right_scores[cand]
        score = a * b / (a + b) if a and b else 0
        scores[cand] = score
        if score > best_score:
            best_item, best_score = cand, score
    return scores, best_item, best_score


def correct_me(str_test):
    """Try to correct one suspicious word of ``str_test`` via homophone lookup.

    The sentence is segmented with NLPIR, the interior word whose two
    surrounding adjacent-word frequencies sum to the minimum is marked as
    suspicious, and three hypotheses are scored:

    * STAGE1 - the segmentation is right; replace the suspicious word.
    * STAGE2 - the previous word and the suspicious word are really one word.
    * STAGE3 - the suspicious word and the next word are really one word.

    Progress is printed to stdout.

    Returns:
        The corrected sentence (the segmented words joined without
        separators), or ``None`` when fewer than three words were found or
        no hypothesis produced a candidate.
    """
    print("")
    print("==NLPIR分词==")
    print("测试语句：%s" %(str_test))
    line_p = hanzi_prep.split_into_sentences(str_test)
    lines = []
    for line_i in line_p:
        lines.extend(line_i)
    str_i = ''.join(lines)
    jieba_i = ' '.join(pynlpir.segment(str_i, pos_tagging=False))
    print("分词结果:%s"%(repr(jieba_i)))
    jieba_i = jieba_i.split()
    jieba_len = len(jieba_i)
    if jieba_len < 3:
        print("词数太小，放弃纠错!")
        return

    # Frequency of every adjacent word pair; sentence start/end markers are
    # deliberately not considered.  Unknown pairs count as 0.
    jieba_key = [prev + cur for prev, cur in zip(jieba_i, jieba_i[1:])]
    jieba_pro = [JIEBA_HZ.get(key) or 0 for key in jieba_key]
    print("分词表:"+repr(jieba_key))
    print("概率表:"+repr(jieba_pro))

    # Each interior word is rated by the sum of the two pairs it takes part
    # in; the lowest sum marks the suspicious word.  The index is therefore
    # never the first or the last word.
    jieba_pro_t = [jieba_pro[i] + jieba_pro[i+1] for i in range(jieba_len-2)]
    min_index = jieba_pro_t.index(min(jieba_pro_t)) + 1
    print("可疑位置:[%d]->%s"%(min_index,jieba_i[min_index]))

    to_do = jieba_i[min_index-1:min_index+2]
    # Optional context words just outside the three-word window.
    g_check_a = jieba_i[min_index-2] if min_index - 2 >= 0 else None
    g_check_e = jieba_i[min_index+2] if min_index + 2 < jieba_len else None

    print("需要处理:"+repr(to_do))
    print("辅助检测:%s,%s" %(g_check_a, g_check_e))

    # STAGE1: segmentation is assumed correct; a replacement must fit both
    # the previous and the next word.
    p_res_1 = _replacements_for_right(to_do[0], to_do[1])
    p_res_2 = _replacements_for_left(to_do[1], to_do[2])
    p_res_stage1, max_item_1, max_pro_1 = _combine_scores(p_res_1, p_res_2)
    if p_res_stage1:
        print(repr(p_res_stage1))

    # STAGE2: first and second word are merged.  The candidate is cut at the
    # length of the merged word (the original cut at the length of the
    # trailing word, which only worked when both lengths were equal).
    to_do_a = [to_do[0]+to_do[1], to_do[2]]
    p_res_3 = _replacements_for_left(to_do_a[0], to_do_a[1])
    p_res_s3 = _with_context(p_res_3, g_check_a, True)
    p_res_stage2, max_item_2, max_pro_2 = _combine_scores(p_res_3, p_res_s3)

    # STAGE3: second and third word are merged.
    to_do_b = [to_do[0], to_do[1]+to_do[2]]
    p_res_4 = _replacements_for_right(to_do_b[0], to_do_b[1])
    p_res_s4 = _with_context(p_res_4, g_check_e, False)
    p_res_stage3, max_item_3, max_pro_3 = _combine_scores(p_res_4, p_res_s4)
    if p_res_stage3:
        print(repr(p_res_stage3))

    # Report the best candidate of every stage.
    if max_item_1:
        print("STAGE1:纠错结果：%s %s %s，概率%f"%(to_do[0],max_item_1,to_do[2],p_res_stage1[max_item_1]))
    else:
        print("STAGE1:纠错失败")
    if max_item_2:
        print("STAGE2:纠错结果：%s %s，概率%f"%(max_item_2,to_do_a[1],p_res_stage2[max_item_2]))
    else:
        print("STAGE2:纠错失败")
    if max_item_3:
        print("STAGE3:纠错结果：%s %s，概率%f"%(to_do_b[0],max_item_3,p_res_stage3[max_item_3]))
    else:
        print("STAGE3:纠错失败")

    max_pro = max([max_pro_1, max_pro_2, max_pro_3])
    if max_pro == 0:
        print('纠错失败')
        return None

    # On ties the earlier stage wins.
    if max_pro == max_pro_1:
        replacement = [to_do[0], max_item_1, to_do[2]]
    elif max_pro == max_pro_2:
        replacement = [max_item_2, to_do_a[1]]
    else:
        replacement = [to_do_b[0], max_item_3]
    final_words = jieba_i[0:min_index-1] + replacement + jieba_i[min_index+2:jieba_len]

    print("原句: "+str_test)
    print("纠正："+''.join(final_words))
    return (''.join(final_words))

# Hand-off queue between DistributeThread (which puts (conn, addr) pairs) and
# ProcessPoolThread (which gets them); holds at most 10 pending entries.
global_q = queue.Queue(maxsize=10)


# Fed by DistributeThread; handles the client requests it accepted.
class ProcessPoolThread(threading.Thread):
    """Worker thread that takes accepted connections from ``global_q`` and answers them."""

    def __init__(self, tid):
        """Create the worker.

        Args:
            tid: numeric identifier used in log output.
        """
        super().__init__()
        # Thread.__init__ returns None, so this attribute has always been
        # None; it is kept only so that existing readers do not break.
        self.cur_thread = None
        self.threadID = tid
        self.pynlpir = pynlpir

    def run(self):
        """Serve queued connections forever."""
        while True:
            # Queue.get() blocks until a connection is available.
            conn, addr = global_q.get()
            try:
                self.handle_process(conn, addr)
            except Exception as exc:
                # One failing client must not terminate the worker loop.
                print("PROCESS T[%d] failed while handling %s: %r"
                      % (self.threadID, repr(addr), exc))

    def handle_process(self, conn, addr):
        """Read one batch of requests from ``conn``, answer them, close ``conn``.

        The payload is the textual form of one or more dict literals (several
        requests may have been merged into one packet).  Every item with a
        non-zero ``CLIENT`` and ``TYPE == 'REQ_COR'`` is passed to
        ``correct_me`` and answered with a ``REP_URL`` dict literal followed
        by a comma.

        Args:
            conn: connected socket-like object (``recv``/``sendall``/``close``).
            addr: peer address, used for logging only.
        """
        # Local import keeps this block self-contained.
        import ast

        print("PROCESS T[%d]正在处理%s" %(self.threadID, repr(addr)))
        try:
            try:
                req_datas = conn.recv(2048).decode()
                # SECURITY: the original used eval() on data received from
                # the network, which lets any client run arbitrary code.
                # literal_eval only accepts Python literals.
                jreq_datas = ast.literal_eval(req_datas)
            except (ValueError, TypeError, SyntaxError) as exc:
                print("MALFORMED CLIENT REQUEST from %s: %r" % (repr(addr), exc))
                return

            for req_item in jreq_datas:
                if req_item['CLIENT'] != 0 and req_item['TYPE'] == 'REQ_COR':
                    rep_data = correct_me(req_item['DATA'])
                    rep_url = {'CLIENT':req_item['CLIENT'],'TYPE':'REP_URL','DATA':rep_data}
                    jrep_url = repr(rep_url) + ','
                    conn.sendall(jrep_url.encode())
                else:
                    print("UKNOWN CLIENT REQUEST!")
        finally:
            # Release the connection deterministically, also on errors.
            conn.close()


class DistributeThread(threading.Thread):
    """Acceptor thread: listens on (host, port) and queues accepted connections on ``global_q``."""

    def __init__(self, host, port, tid):
        """Create the acceptor and its (not yet bound) TCP socket.

        Args:
            host: address to bind to.
            port: TCP port to bind to.
            tid: numeric identifier used in log output.
        """
        super().__init__()
        self.host = host
        self.port = port
        self.threadID = tid
        self.sk = socket.socket(socket.AF_INET,socket.SOCK_STREAM)

    def run(self):
        """Bind, listen and hand every accepted connection to the workers.

        The loop never ends on its own; it is left only through an
        exception, in which case the listening socket is closed before the
        exception propagates.
        """
        print ("启动服务器分发线程 %d ...\n" % self.threadID)
        try:
            self.sk.bind((self.host,self.port))
            self.sk.listen(1)

            while True:
                conn, addr = self.sk.accept()
                print("SERVER[%d], Recv request from %s" % (self.threadID, addr))
                # Blocks while the queue is full.
                global_q.put((conn, addr))
        finally:
            self.sk.close()
        
if __name__ == '__main__':
    # Command-line entry point: corrects the sentences given as arguments
    # (or two built-in samples) and then prints a raw NLPIR segmentation of
    # a sample sentence.

    dict_file = c_char_p(b'MY_DICT.dat')
    print(type(dict_file))
    print(type(pynlpir.nlpir.ImportUserDict))

    # NLPIR is initialised through the low-level binding so that the user
    # dictionary can be imported afterwards.
    nlpir.Init(c_char_p(b'/home/user/Study/MachLearn/sogou_txt/utf-8/pynlpir/pynlpir'), nlpir.UTF8_CODE, None)
    try:
        nlpir.ImportUserDict(dict_file)

        if len(sys.argv) > 1:
            for sentence in sys.argv[1:]:
                correct_me(sentence)
        else:
            correct_me("到乌鲁木齐邮飞怎么算")
            correct_me("我想要一个双卡双待的手机")

        ss = bytes("我想要一个双卡双待的手机", encoding='utf-8')
        s = c_char_p(ss)
        size = ctypes.c_int()
        result = nlpir.ParagraphProcessA(s, ctypes.byref(size), True)
        print(repr(result))
        result_t_vector = ctypes.cast(result, ctypes.POINTER(nlpir.ResultT))
        words = []
        for i in range(0, size.value):
            r = result_t_vector[i]
            # The original sliced the c_char_p object ``s``, which is not
            # subscriptable; slice the underlying bytes buffer instead.
            # NOTE(review): assumes start/length are byte offsets into the
            # UTF-8 buffer - confirm against the NLPIR result_t definition.
            word = ss[r.start:r.start+r.length].decode('utf-8', errors='replace')
            words.append((word, r.sPOS))

        print(words)
    finally:
        # Always release NLPIR, even when a step above fails.
        nlpir.Exit()

    print("你好,起已退出完毕")
